"""Custom decoder definition for Transducer model."""

from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from espnet.nets.pytorch_backend.transducer.blocks import build_blocks
from espnet.nets.pytorch_backend.transducer.utils import check_batch_states
from espnet.nets.pytorch_backend.transducer.utils import check_state
from espnet.nets.pytorch_backend.transducer.utils import pad_sequence
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.transducer_decoder_interface import ExtendedHypothesis
from espnet.nets.transducer_decoder_interface import Hypothesis
from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface


class CustomDecoder(TransducerDecoderInterface, torch.nn.Module):
    """Custom decoder module for Transducer model.

    The decoder is an input (embedding) layer followed by a stack of blocks
    created by ``build_blocks`` and a final layer normalization.

    Args:
        odim: Output dimension.
        dec_arch: Decoder block architecture (type and parameters).
        input_layer: Input layer type.
        repeat_block: Number of times dec_arch is repeated.
        joint_activation_type: Type of activation for joint network.
            (Accepted for interface compatibility; not used in this class.)
        positional_encoding_type: Positional encoding type.
        positionwise_layer_type: Positionwise layer type.
        positionwise_activation_type: Positionwise activation type.
        input_layer_dropout_rate: Dropout rate for input layer.
        blank_id: Blank symbol ID.

    """

    def __init__(
        self,
        odim: int,
        dec_arch: List,
        input_layer: str = "embed",
        repeat_block: int = 0,
        joint_activation_type: str = "tanh",
        positional_encoding_type: str = "abs_pos",
        positionwise_layer_type: str = "linear",
        positionwise_activation_type: str = "relu",
        input_layer_dropout_rate: float = 0.0,
        blank_id: int = 0,
    ):
        """Construct a CustomDecoder object."""
        torch.nn.Module.__init__(self)

        # The blank ID doubles as the padding index of the input layer.
        self.embed, self.decoders, ddim, _ = build_blocks(
            "decoder",
            odim,
            input_layer,
            dec_arch,
            repeat_block=repeat_block,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            input_layer_dropout_rate=input_layer_dropout_rate,
            padding_idx=blank_id,
        )

        self.after_norm = LayerNorm(ddim)

        self.dlayers = len(self.decoders)
        self.dunits = ddim
        self.odim = odim

        self.blank_id = blank_id

    @staticmethod
    def _cache_key(yseq: List[int]) -> str:
        """Build the cache key of a label ID sequence (IDs joined by "_")."""
        return "_".join(map(str, yseq))

    def set_device(self, device: torch.device):
        """Set the device used to create tensors during decoding.

        ``score`` and ``batch_score`` read ``self.device``, so this method has
        to be called before either of them.

        Args:
            device: Device ID.

        """
        self.device = device

    def init_state(
        self,
        batch_size: Optional[int] = None,
    ) -> List[Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size. (Unused: the initial state is one ``None``
                per decoder block, whatever the batch size.)

        Returns:
            state: Initial decoder hidden states. [N x None]

        """
        state = [None] * self.dlayers

        return state

    def forward(
        self, dec_input: torch.Tensor, dec_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode label ID sequences.

        Args:
            dec_input: Label ID sequences. (B, U)
            dec_mask: Label mask sequences.  (B, U)

        Return:
            dec_output: Decoder output sequences. (B, U, D_dec)
            dec_output_mask: Mask of decoder output sequences. (B, U)

        """
        dec_input = self.embed(dec_input)

        dec_output, dec_mask = self.decoders(dec_input, dec_mask)
        dec_output = self.after_norm(dec_output)

        return dec_output, dec_mask

    def score(
        self, hyp: Hypothesis, cache: Dict[str, Any]
    ) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
        """One-step forward hypothesis.

        Args:
            hyp: Hypothesis.
            cache: Pairs of (dec_out, dec_state) for each label sequence. (key)
                Updated in place when the hypothesis was not cached yet.

        Returns:
            dec_out: Decoder output for the last label, with the batch axis
                indexed away. (D_dec)
            dec_state: Decoder hidden states. [N x (1, U, D_dec)]
            lm_label: Label ID for LM. (1,)

        """
        labels = torch.tensor([hyp.yseq], device=self.device)
        lm_label = labels[:, -1]

        str_labels = self._cache_key(hyp.yseq)

        if str_labels in cache:
            dec_out, dec_state = cache[str_labels]
        else:
            # Keep the causal mask on the same device as the labels.
            dec_out_mask = (
                subsequent_mask(len(hyp.yseq)).unsqueeze_(0).to(self.device)
            )

            new_state = check_state(hyp.dec_state, (labels.size(1) - 1), self.blank_id)

            dec_out = self.embed(labels)

            # The output of each block is kept as that block's new state.
            dec_state = []
            for s, decoder in zip(new_state, self.decoders):
                dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
                dec_state.append(dec_out)

            # Only the last position is normalized and returned.
            dec_out = self.after_norm(dec_out[:, -1])

            cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state, lm_label

    def batch_score(
        self,
        hyps: Union[List[Hypothesis], List[ExtendedHypothesis]],
        dec_states: List[Optional[torch.Tensor]],
        cache: Dict[str, Any],
        use_lm: bool,
    ) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], Optional[torch.Tensor]]:
        """One-step forward hypotheses.

        Hypotheses whose label sequence is already in the cache are reused;
        the remaining ones are padded with the blank ID and decoded together.

        Args:
            hyps: Hypotheses.
            dec_states: Decoder hidden states. [N x (B, U, D_dec)]
            cache: Pairs of (h_dec, dec_states) for each label sequences. (keys)
                Updated in place with the newly decoded hypotheses.
            use_lm: Whether to compute label ID sequences for LM.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            dec_states: Decoder hidden states. [N x (B, U, D_dec)]
            lm_labels: Label ID sequences for LM (B,), or None if use_lm is False.

        """
        done = [None] * len(hyps)

        # (batch index, cache key, hypothesis) of each hypothesis to decode.
        pending = []
        for i, hyp in enumerate(hyps):
            str_labels = self._cache_key(hyp.yseq)

            if str_labels in cache:
                done[i] = cache[str_labels]
            else:
                pending.append((i, str_labels, hyp))

        if pending:
            labels = pad_sequence([hyp.yseq for _, _, hyp in pending], self.blank_id)
            # torch.tensor (unlike the legacy torch.LongTensor constructor)
            # accepts any target device.
            labels = torch.tensor(labels, dtype=torch.long, device=self.device)

            p_dec_states = self.create_batch_states(
                self.init_state(),
                [hyp.dec_state for _, _, hyp in pending],
                labels,
            )

            dec_out = self.embed(labels)

            # Same causal mask for every pending hypothesis, on the labels' device.
            dec_out_mask = (
                subsequent_mask(labels.size(-1))
                .unsqueeze_(0)
                .to(self.device)
                .expand(len(pending), -1, -1)
            )

            new_states = []
            for s, decoder in zip(p_dec_states, self.decoders):
                dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
                new_states.append(dec_out)

            dec_out = self.after_norm(dec_out[:, -1])

            # Row j of the processed batch belongs to batch position i.
            for j, (i, str_labels, _) in enumerate(pending):
                entry = (dec_out[j], self.select_state(new_states, j))

                done[i] = entry
                cache[str_labels] = entry

        dec_out = torch.stack([d[0] for d in done])
        dec_states = self.create_batch_states(
            dec_states, [d[1] for d in done], [[0] + h.yseq for h in hyps]
        )

        if use_lm:
            lm_labels = torch.tensor(
                [hyp.yseq[-1] for hyp in hyps],
                dtype=torch.long,
                device=self.device,
            )

            return dec_out, dec_states, lm_labels

        return dec_out, dec_states, None

    def select_state(
        self, states: List[Optional[torch.Tensor]], idx: int
    ) -> List[Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. [N x (B, U, D_dec)]
            idx: State ID to extract.

        Returns:
            state_idx: Decoder hidden state for given ID, i.e. ``states[layer][idx]``
                for each layer. ``states`` itself is returned when its first
                entry is None.

        """
        if states[0] is None:
            return states

        state_idx = [states[layer][idx] for layer in range(self.dlayers)]

        return state_idx

    def create_batch_states(
        self,
        states: List[Optional[torch.Tensor]],
        new_states: List[Optional[torch.Tensor]],
        check_list: List[List[int]],
    ) -> List[Optional[torch.Tensor]]:
        """Create decoder hidden states sequences.

        Args:
            states: Decoder hidden states. [N x (B, U, D_dec)]
                Modified in place and returned.
            new_states: Decoder hidden states. [B x [N x (1, U, D_dec)]]
            check_list: Label ID sequences.

        Returns:
            states: New decoder hidden states. [N x (B, U, D_dec)]

        """
        # Nothing to batch when the first hypothesis has no state yet.
        if new_states[0][0] is None:
            return states

        max_len = max(len(elem) for elem in check_list) - 1

        for layer in range(self.dlayers):
            states[layer] = check_batch_states(
                [s[layer] for s in new_states], max_len, self.blank_id
            )

        return states
